import argparse
import ast
import os
import re
from collections import Counter, defaultdict
from typing import List
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import sys

# Prepend the LocAgent checkout so its ast_analysis package wins the import
# search. NOTE: insert(0) actually prepends; the previous insert(-1) slotted
# the path *before the last* entry, which neither prepends nor appends.
sys.path.insert(0, "/data_ext/ref_code/LocAgent")
from ast_analysis.ast_analysis_java import analyze_file_java
from ast_analysis.ast_analysis_py import analyze_file_py

# Builder version tag (not referenced elsewhere in this file).
VERSION = "v2.3"
# Node categories stored in each graph node's "type" attribute.
NODE_TYPE_DIRECTORY = "directory"
NODE_TYPE_FILE = "file"
NODE_TYPE_CLASS = "class"
NODE_TYPE_FUNCTION = "function"
# Edge categories stored in each graph edge's "type" attribute.
EDGE_TYPE_CONTAINS = "contains"
EDGE_TYPE_INHERITS = "inherits"
EDGE_TYPE_INVOKES = "invokes"
EDGE_TYPE_IMPORTS = "imports"

# Canonical orderings of the types above, used for statistics/reporting.
VALID_NODE_TYPES = [
    NODE_TYPE_DIRECTORY,
    NODE_TYPE_FILE,
    NODE_TYPE_CLASS,
    NODE_TYPE_FUNCTION,
]
VALID_EDGE_TYPES = [
    EDGE_TYPE_CONTAINS,
    EDGE_TYPE_INHERITS,
    EDGE_TYPE_INVOKES,
    EDGE_TYPE_IMPORTS,
]

# Directory names excluded from the repository walk.
SKIP_DIRS = [".github", ".git"]


def is_skip_dir(dirname):
    """Return True when *dirname* (a repo-relative path) should be skipped.

    A path is skipped when any of its components exactly matches an entry in
    SKIP_DIRS. Exact component matching fixes the substring false positives
    of the old check (e.g. a directory named "my.gitstore" was skipped
    because it merely contained ".git").
    """
    for part in dirname.split(os.sep):
        if part in SKIP_DIRS:
            return True
    return False


def handle_edge_cases(code):
    # hard-coded edge cases
    code = code.replace("\ufeff", "")
    code = code.replace("constants.False", "_False")
    code = code.replace("constants.True", "_True")
    code = code.replace("False", "_False")
    code = code.replace("True", "_True")
    code = code.replace("DOMAIN\\username", "DOMAIN\\\\username")
    code = code.replace("Error, ", "Error as ")
    code = code.replace("Exception, ", "Exception as ")
    code = code.replace("print ", "yield ")
    pattern = r"except\s+\(([^,]+)\s+as\s+([^)]+)\):"
    # Replace 'as' with ','
    code = re.sub(pattern, r"except (\1, \2):", code)
    code = code.replace("raise AttributeError as aname", "raise AttributeError")
    return code


def find_imports(filepath, repo_path, tree=None):
    """Collect the import statements of a file, or of one class/function node.

    Args:
        filepath: Absolute path of the Python file; used to resolve relative
            imports and, when ``tree`` is None, to read the source.
        repo_path: Repository root, used to turn relative imports into
            absolute dotted module names.
        tree: Optional AST node. When given, only its direct children are
            scanned (i.e. the top-level imports of a class/function body);
            otherwise the whole file is parsed and walked recursively.

    Returns:
        List of dicts, one per import:
        ``{"type": "import", "module": str, "alias": str | None}`` or
        ``{"type": "from", "module": str, "entities": [{"name", "alias"}]}``.

    Raises:
        SyntaxError: If the file cannot be read, decoded, or parsed.
    """
    if tree is None:
        try:
            with open(filepath, "r") as file:
                tree = ast.parse(file.read(), filename=filepath)
        except (OSError, UnicodeDecodeError, SyntaxError, ValueError) as exc:
            # Normalize read/parse failures to SyntaxError, the exception
            # type callers catch. (The previous bare `except:` also swallowed
            # KeyboardInterrupt/SystemExit and dropped the original cause.)
            raise SyntaxError(str(exc)) from exc
        # include all imports anywhere in the file
        candidates = ast.walk(tree)
    else:
        # only include top-level imports of the given class/function node
        candidates = ast.iter_child_nodes(tree)

    imports = []
    for node in candidates:
        if isinstance(node, ast.Import):
            # Handle 'import module' and 'import module as alias'
            for alias in node.names:
                imports.append(
                    {"type": "import", "module": alias.name, "alias": alias.asname}
                )
        elif isinstance(node, ast.ImportFrom):
            # Handle 'from ... import ...' statements
            import_entities = []
            for alias in node.names:
                if alias.name == "*":
                    # A star import collapses the whole entity list.
                    import_entities = [{"name": "*", "alias": None}]
                    break
                import_entities.append({"name": alias.name, "alias": alias.asname})

            if node.level == 0:
                # Absolute import: the module name is given directly.
                module_name = node.module
            else:
                # Relative import: rebuild the dotted name from the file's
                # location. The path parts include the filename itself, so
                # stripping `level` parts maps level=1 to this file's package.
                rel_path = os.path.relpath(filepath, repo_path)
                package_parts = rel_path.split(os.sep)
                if len(package_parts) >= node.level:
                    package_parts = package_parts[: -node.level]
                else:
                    # Import reaches above the repo root; best effort.
                    package_parts = []

                if node.module:
                    module_name = ".".join(package_parts + [node.module])
                else:
                    module_name = ".".join(package_parts)

            imports.append(
                {"type": "from", "module": module_name, "entities": import_entities}
            )
    return imports


def resolve_module(module_name, repo_path):
    """
    Resolve a dotted module name to a file path inside the repo.

    Args:
        module_name: Dotted module path, e.g. "pkg.mod" (may be None/empty).
        repo_path: Repository root directory.

    Returns:
        The path of "pkg/mod.py", or of "pkg/mod/__init__.py" when the name
        is a package, or None when neither exists (or the name is empty).
    """
    if not module_name:
        # Guard against None/"" (e.g. a relative import at the repo root).
        return None

    # Join the dotted parts so the platform separator is used instead of a
    # hard-coded "/".
    base = os.path.join(repo_path, *module_name.split("."))

    # Try to resolve as a .py file
    module_path = base + ".py"
    if os.path.isfile(module_path):
        return module_path

    # Try to resolve as a package (__init__.py)
    init_path = os.path.join(base, "__init__.py")
    if os.path.isfile(init_path):
        return init_path

    return None


def add_imports(root_node, imports, graph, repo_path):
    """Add EDGE_TYPE_IMPORTS edges from root_node to the graph nodes that the
    given import records (as produced by find_imports) resolve to. Edges are
    only added when the resolved target node already exists in the graph."""

    def _relative(path):
        # Repo-relative form, which is how file nodes are keyed in the graph.
        return os.path.relpath(path, repo_path)

    def _link(target_node, alias):
        # Add an imports edge only when the target node is already present.
        if graph.has_node(target_node):
            graph.add_edge(root_node, target_node, type=EDGE_TYPE_IMPORTS, alias=alias)

    for imp in imports:
        if imp["type"] == "import":
            # 'import module [as alias]' -> edge to the module's file node.
            resolved = resolve_module(imp["module"], repo_path)
            if resolved:
                _link(_relative(resolved), imp["alias"])
        elif imp["type"] == "from":
            module_name = imp["module"]
            entities = imp["entities"]

            # 'from module import *' is treated like a plain module import.
            if len(entities) == 1 and entities[0]["name"] == "*":
                resolved = resolve_module(module_name, repo_path)
                if resolved:
                    _link(_relative(resolved), None)
                continue

            for entity in entities:
                entity_name, entity_alias = entity["name"], entity["alias"]

                # Case 1: the imported entity is itself a submodule/package.
                submodule = resolve_module(f"{module_name}.{entity_name}", repo_path)
                if submodule:
                    _link(_relative(submodule), entity_alias)
                    continue

                # Case 2: the entity is an attribute (class/function) defined
                # inside the module file.
                container = resolve_module(module_name, repo_path)
                if not container:
                    continue
                container_rel = _relative(container)
                qualified = f"{container_rel}:{entity_name}"
                if graph.has_node(qualified):
                    # Prefer the precise class/function node...
                    graph.add_edge(
                        root_node, qualified, type=EDGE_TYPE_IMPORTS, alias=entity_alias
                    )
                elif graph.has_node(container_rel):
                    # ...falling back to the containing file node.
                    graph.add_edge(
                        root_node,
                        container_rel,
                        type=EDGE_TYPE_IMPORTS,
                        alias=entity_alias,
                    )

def resolve_symlink(file_path):
    """
    Resolve the absolute path of a symbolic link.

    Args:
        file_path (str): The symbolic link file path.

    Returns:
        str: The absolute path of the target file if the file is a symbolic link.
        None: If the file is not a symbolic link, or its target does not exist.
    """
    if not os.path.islink(file_path):
        print(f"{file_path} is not a symbolic link.")
        return None

    # Get the (possibly relative) path to the target file
    relative_target = os.readlink(file_path)
    # A relative link target resolves against the directory that CONTAINS
    # the link (a single dirname). The previous double-dirname resolved
    # against the grandparent directory and broke ordinary sibling links.
    symlink_dir = os.path.dirname(file_path)
    # os.path.join keeps absolute targets unchanged and anchors relative ones
    absolute_target = os.path.abspath(os.path.join(symlink_dir, relative_target))
    if not os.path.exists(absolute_target):
        print(f"The target file does not exist: {absolute_target}")
        return None
    return absolute_target


def build_graph(repo_path, fuzzy_search=True, global_import=False):
    """Traverse all Python files under repo_path and build a dependency graph.

    Node types: directory, file, class, function. Class/function nodes are
    keyed "<relative_file_path>:<qualified_name>" so downstream analysis
    (analyze_init/analyze_invokes) can recover the defining file via
    ``node.split(":")[0]``. (The previous "(unknown):" prefix broke that
    contract.) Edge types: contains, imports, invokes, inherits.

    Args:
        repo_path: Root directory of the repository to analyze.
        fuzzy_search: When True, an invoked name links to every candidate
            node sharing that suffix; when False only the nearest one.
        global_import: When True, names unresolved through imports fall back
            to a repository-wide lookup by bare name.

    Returns:
        nx.MultiDiGraph of the repository structure and call relations.
    """
    graph = nx.MultiDiGraph()
    file_nodes = {}  # relative filename -> absolute path, for the import pass

    ## add nodes
    graph.add_node("/", type=NODE_TYPE_DIRECTORY)
    dir_stack: List[str] = []
    dir_include_stack: List[bool] = []  # parallel to dir_stack: subtree has .py?
    for root, _, files in os.walk(repo_path):

        # add directory nodes and edges
        dirname = os.path.relpath(root, repo_path)
        if dirname == ".":
            dirname = "/"
        elif is_skip_dir(dirname):
            continue
        else:
            graph.add_node(dirname, type=NODE_TYPE_DIRECTORY)
            parent_dirname = os.path.dirname(dirname)
            if parent_dirname == "":
                parent_dirname = "/"
            graph.add_edge(parent_dirname, dirname, type=EDGE_TYPE_CONTAINS)

        # when the walk steps back up the tree, prune finished directories
        # whose subtree contained no .py file
        while len(dir_stack) > 0 and not dirname.startswith(dir_stack[-1]):
            if not dir_include_stack[-1]:
                graph.remove_node(dir_stack[-1])
            dir_stack.pop()
            dir_include_stack.pop()
        if dirname != "/":
            dir_stack.append(dirname)
            dir_include_stack.append(False)

        dir_has_py = False
        for file in files:
            if not file.endswith(".py"):
                continue
            dir_has_py = True
            try:
                file_path = os.path.join(root, file)
                filename = os.path.relpath(file_path, repo_path)
                if os.path.islink(file_path):
                    # Skip symlinks to avoid duplicate module nodes.
                    continue
                with open(file_path, "r") as f:
                    file_content = f.read()

                graph.add_node(filename, type=NODE_TYPE_FILE, code=file_content)
                file_nodes[filename] = file_path

                nodes = analyze_file_py(file_path)
            except (UnicodeDecodeError, SyntaxError):
                # Skip files that cannot be decoded or parsed
                continue

            # add function/class nodes, keyed by their defining file
            for node in nodes:
                full_name = f'{filename}:{node["name"]}'
                graph.add_node(
                    full_name,
                    type=node["type"],
                    code=node["code"],
                    start_line=node["start_line"],
                    end_line=node["end_line"],
                )

            # add edges with type=contains
            graph.add_edge(dirname, filename, type=EDGE_TYPE_CONTAINS)
            for node in nodes:
                full_name = f'{filename}:{node["name"]}'
                name_list = node["name"].split(".")
                if len(name_list) == 1:
                    graph.add_edge(filename, full_name, type=EDGE_TYPE_CONTAINS)
                else:
                    # Nested definition: hang it off its enclosing scope.
                    parent_name = ".".join(name_list[:-1])
                    full_parent_name = f"{filename}:{parent_name}"
                    graph.add_edge(
                        full_parent_name, full_name, type=EDGE_TYPE_CONTAINS
                    )

        # a directory with .py files keeps all of its ancestors as well
        if dir_has_py:
            for i in range(len(dir_include_stack)):
                dir_include_stack[i] = True

    # flush the stack for the last traversed branch
    while len(dir_stack) > 0:
        if not dir_include_stack[-1]:
            graph.remove_node(dir_stack[-1])
        dir_stack.pop()
        dir_include_stack.pop()

    ## add imports edges (file -> file/class/function)
    for filename, filepath in file_nodes.items():
        try:
            imports = find_imports(filepath, repo_path)
        except SyntaxError:
            continue
        add_imports(filename, imports, graph, repo_path)

    # optional repo-wide index (bare name -> nodes) for fuzzy fallback
    global_name_dict = defaultdict(list)
    if global_import:
        for node in graph.nodes():
            node_name = node.split(":")[-1].split(".")[-1]
            global_name_dict[node_name].append(node)

    ## add edges that start from class/function nodes
    for node, attributes in graph.nodes(data=True):
        if attributes.get("type") not in [NODE_TYPE_CLASS, NODE_TYPE_FUNCTION]:
            continue

        caller_code_tree = ast.parse(graph.nodes[node]["code"])

        # construct possible callee dict (name -> node) based on graph connectivity
        callee_nodes, callee_alias = find_all_possible_callee(node, graph)
        if fuzzy_search:
            # for nodes with the same suffix, keep every candidate
            callee_name_dict = defaultdict(list)
            for callee_node in set(callee_nodes):
                callee_name = callee_node.split(":")[-1].split(".")[-1]
                callee_name_dict[callee_name].append(callee_node)
            for alias, callee_node in callee_alias.items():
                callee_name_dict[alias].append(callee_node)
        else:
            # for nodes with the same suffix, only keep the nearest node
            callee_name_dict = {
                callee_node.split(":")[-1].split(".")[-1]: callee_node
                for callee_node in callee_nodes[::-1]
            }
            callee_name_dict.update(callee_alias)

        # analyze invokes and inherits; this also registers the node's
        # (top-level) imports edges (class/function -> class/function)
        if attributes.get("type") == NODE_TYPE_CLASS:
            invocations, inheritances = analyze_init(
                node, caller_code_tree, graph, repo_path
            )
        else:
            invocations = analyze_invokes(node, caller_code_tree, graph, repo_path)
            inheritances = []

        # add invokes edges (class/function -> class/function)
        for callee_name in set(invocations):
            callee_node = callee_name_dict.get(callee_name)
            if callee_node:
                if isinstance(callee_node, list):
                    for callee in callee_node:
                        graph.add_edge(node, callee, type=EDGE_TYPE_INVOKES)
                else:
                    graph.add_edge(node, callee_node, type=EDGE_TYPE_INVOKES)
            elif global_import:
                # fall back to the repo-wide name index
                global_fuzzy_nodes = global_name_dict.get(callee_name)
                if global_fuzzy_nodes:
                    for global_fuzzy_node in global_fuzzy_nodes:
                        graph.add_edge(node, global_fuzzy_node, type=EDGE_TYPE_INVOKES)

        # add inherits edges (class -> class)
        for callee_name in set(inheritances):
            callee_node = callee_name_dict.get(callee_name)
            if callee_node:
                if isinstance(callee_node, list):
                    for callee in callee_node:
                        graph.add_edge(node, callee, type=EDGE_TYPE_INHERITS)
                else:
                    graph.add_edge(node, callee_node, type=EDGE_TYPE_INHERITS)
            elif global_import:
                # fall back to the repo-wide name index
                global_fuzzy_nodes = global_name_dict.get(callee_name)
                if global_fuzzy_nodes:
                    for global_fuzzy_node in global_fuzzy_nodes:
                        graph.add_edge(node, global_fuzzy_node, type=EDGE_TYPE_INHERITS)

    return graph


def get_inner_nodes(query_node, src_node, graph):
    """Collect the nodes reachable from src_node via 'contains' edges,
    excluding query_node itself; only class nodes are expanded recursively
    (their members are also visible callees)."""
    collected = []
    for _, child, edge_attr in graph.edges(src_node, data=True):
        if edge_attr["type"] != EDGE_TYPE_CONTAINS or child == query_node:
            continue
        collected.append(child)
        if graph.nodes[child]["type"] == NODE_TYPE_CLASS:
            # only include a class's inner nodes
            collected.extend(get_inner_nodes(query_node, child, graph))
    return collected


def find_all_possible_callee(node, graph):
    """Walk up the 'contains' chain from a class/function node to its file,
    collecting every node that could be referenced by name from its body.

    Returns:
        (callee_nodes, callee_alias): a list of candidate callee node ids,
        and a mapping from import alias -> node id.
    """
    callee_nodes, callee_alias = [], {}
    cur_node = node
    pre_node = node

    # Find every possible callee: all definitions visible from this node's
    # enclosing scopes, plus everything its file imports.
    def find_parent(_cur_node):
        # Return the predecessor that 'contains' _cur_node (its enclosing
        # class or file). E.g. for a call A.func() in bb.py we must locate the
        # file that actually defines A (say aa.py) — this is hard to recover
        # from import statements alone, so we climb the containment edges.
        # Returns None implicitly when no containing predecessor exists.
        for predecessor in graph.predecessors(_cur_node):
            for key, attr in graph.get_edge_data(predecessor, _cur_node).items():
                if attr["type"] == EDGE_TYPE_CONTAINS:
                    return predecessor

    while True:
        # Siblings/members of the current scope (excluding the scope we came
        # from) are visible callees.
        callee_nodes.extend(get_inner_nodes(pre_node, cur_node, graph))

        if graph.nodes[cur_node]["type"] == NODE_TYPE_FILE:

            # check recursive imported files
            # breadth-first over the file's import edges to gather package
            # files whose re-exported members are also visible.
            # NOTE(review): only "__init__.py" targets are followed here, so
            # re-exports through regular modules are not chased — confirm intended.
            file_list = []
            file_stack = [cur_node]
            while len(file_stack) > 0:
                for _, dst_node, attr in graph.edges(file_stack.pop(), data=True):
                    if attr[
                        "type"
                    ] == EDGE_TYPE_IMPORTS and dst_node not in file_list + [cur_node]:
                        if graph.nodes[dst_node][
                            "type"
                        ] == NODE_TYPE_FILE and dst_node.endswith("__init__.py"):
                            file_list.append(dst_node)
                            file_stack.append(dst_node)

            # Everything those package files contain or import is a candidate.
            for file in file_list:
                callee_nodes.extend(get_inner_nodes(cur_node, file, graph))
                for _, dst_node, attr in graph.edges(file, data=True):
                    if attr["type"] == EDGE_TYPE_IMPORTS:
                        if attr["alias"] is not None:
                            callee_alias[attr["alias"]] = dst_node
                        if graph.nodes[dst_node]["type"] in [
                            NODE_TYPE_FILE,
                            NODE_TYPE_CLASS,
                        ]:
                            callee_nodes.extend(get_inner_nodes(file, dst_node, graph))
                        if graph.nodes[dst_node]["type"] in [
                            NODE_TYPE_FUNCTION,
                            NODE_TYPE_CLASS,
                        ]:
                            callee_nodes.append(dst_node)

            # check the functions and classes imported directly by this file
            for _, dst_node, attr in graph.edges(cur_node, data=True):
                if attr["type"] == EDGE_TYPE_IMPORTS:
                    if attr["alias"] is not None:
                        callee_alias[attr["alias"]] = dst_node
                    if graph.nodes[dst_node]["type"] in [
                        NODE_TYPE_FILE,
                        NODE_TYPE_CLASS,
                    ]:
                        callee_nodes.extend(get_inner_nodes(cur_node, dst_node, graph))
                    if graph.nodes[dst_node]["type"] in [
                        NODE_TYPE_FUNCTION,
                        NODE_TYPE_CLASS,
                    ]:
                        callee_nodes.append(dst_node)

            # reached the file level: nothing further to climb
            break

        pre_node = cur_node
        cur_node = find_parent(cur_node)

    return callee_nodes, callee_alias


def analyze_init(node, code_tree, graph, repo_path):
    """Analyze a class node: collect the names of its base classes, plus the
    names invoked by its decorators and its __init__ body; also register the
    class's top-level imports (and __init__'s) as graph edges.

    Returns:
        (invocations, inheritances): two lists of bare names.
    """
    class_name = node.split(":")[-1].split(".")[-1]
    file_path = os.path.join(repo_path, node.split(":")[0])

    invocations = []
    inheritances = []

    def record_decorator(decorator):
        # A bare-name decorator is itself an invocation; otherwise collect
        # called names and attribute accesses inside the decorator expression.
        if isinstance(decorator, ast.Name):
            invocations.append(decorator.id)
            return
        for inner in ast.walk(decorator):
            if isinstance(inner, ast.Call) and isinstance(inner.func, ast.Name):
                invocations.append(inner.func.id)
            elif isinstance(inner, ast.Attribute):
                invocations.append(inner.attr)

    def record_base(base):
        # Base classes may appear as plain names or dotted attributes.
        if isinstance(base, ast.Attribute):
            inheritances.append(base.attr)
        if isinstance(base, ast.Name):
            inheritances.append(base.id)

    def record_call(call):
        if isinstance(call.func, ast.Name):  # function or class
            invocations.append(call.func.id)
        if isinstance(call.func, ast.Attribute):  # member function
            invocations.append(call.func.attr)

    for candidate in ast.walk(code_tree):
        if not (isinstance(candidate, ast.ClassDef) and candidate.name == class_name):
            continue

        # register the class body's top-level imports
        add_imports(
            node, find_imports(file_path, repo_path, tree=candidate), graph, repo_path
        )

        for base in candidate.bases:
            record_base(base)
        for decorator in candidate.decorator_list:
            record_decorator(decorator)

        # only the first __init__ method is analyzed
        for member in candidate.body:
            if isinstance(member, ast.FunctionDef) and member.name == "__init__":
                # register __init__'s top-level imports
                add_imports(
                    node,
                    find_imports(file_path, repo_path, tree=member),
                    graph,
                    repo_path,
                )
                for decorator in member.decorator_list:
                    record_decorator(decorator)
                for inner in ast.walk(member):
                    if isinstance(inner, ast.Call):
                        record_call(inner)
                break
        break

    return invocations, inheritances


def analyze_invokes(node, code_tree, graph, repo_path):
    """Analyze a function node: collect every name invoked by its decorators
    and body (skipping nested function/class definitions), and register its
    top-level imports as graph edges.

    Returns:
        List of invoked bare names (may contain duplicates).
    """
    func_name = node.split(":")[-1].split(".")[-1]
    file_path = os.path.join(repo_path, node.split(":")[0])

    invocations = []

    def record_decorator(decorator):
        # A bare-name decorator is itself an invocation; otherwise collect
        # called names and attribute accesses inside the decorator expression.
        if isinstance(decorator, ast.Name):
            invocations.append(decorator.id)
            return
        for inner in ast.walk(decorator):
            if isinstance(inner, ast.Call) and isinstance(inner.func, ast.Name):
                invocations.append(inner.func.id)
            elif isinstance(inner, ast.Attribute):
                invocations.append(inner.attr)

    def collect_calls(parent):
        # Depth-first over children, skipping nested def/class subtrees so
        # only this function's own body is analyzed.
        for child in ast.iter_child_nodes(parent):
            if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                continue
            if isinstance(child, ast.Call):
                if isinstance(child.func, ast.Name):
                    invocations.append(child.func.id)
                elif isinstance(child.func, ast.Attribute):
                    invocations.append(child.func.attr)
            collect_calls(child)

    for candidate in ast.walk(code_tree):
        if (
            isinstance(candidate, (ast.FunctionDef, ast.AsyncFunctionDef))
            and candidate.name == func_name
        ):
            # register the function's top-level imports
            add_imports(
                node,
                find_imports(file_path, repo_path, tree=candidate),
                graph,
                repo_path,
            )
            for decorator in candidate.decorator_list:
                record_decorator(decorator)
            collect_calls(candidate)
            break

    return invocations


def visualize_graph(G):
    """Render the dependency graph to plots/dp_v3.png.

    Nodes get a shape/color per node type and edges a color/linestyle per
    edge type; parallel edges between the same pair of nodes are fanned out
    with different arc radii so they remain distinguishable.

    Args:
        G: nx.MultiDiGraph with "type" attributes on nodes and edges.
    """
    node_types = set(nx.get_node_attributes(G, "type").values())
    node_shapes = {
        NODE_TYPE_CLASS: "o",
        NODE_TYPE_FUNCTION: "s",
        NODE_TYPE_FILE: "D",
        NODE_TYPE_DIRECTORY: "^",
    }
    node_colors = {
        NODE_TYPE_CLASS: "lightgreen",
        NODE_TYPE_FUNCTION: "lightblue",
        NODE_TYPE_FILE: "lightgrey",
        NODE_TYPE_DIRECTORY: "orange",
    }

    edge_types = set(nx.get_edge_attributes(G, "type").values())
    edge_colors = {
        EDGE_TYPE_IMPORTS: "forestgreen",
        EDGE_TYPE_CONTAINS: "skyblue",
        EDGE_TYPE_INVOKES: "magenta",
        EDGE_TYPE_INHERITS: "brown",
    }
    edge_styles = {
        EDGE_TYPE_IMPORTS: "solid",
        EDGE_TYPE_CONTAINS: "dashed",
        EDGE_TYPE_INVOKES: "dotted",
        EDGE_TYPE_INHERITS: "dashdot",
    }

    # pos = nx.spring_layout(G, k=2, iterations=50)
    pos = nx.shell_layout(G)
    # pos = nx.circular_layout(G, scale=2, center=(0, 0))

    plt.figure(figsize=(20, 20))
    plt.margins(0.15)  # Add padding around the plot

    # Draw nodes with different shapes and colors based on their type
    for ntype in node_types:
        nodelist = [n for n, d in G.nodes(data=True) if d["type"] == ntype]
        nx.draw_networkx_nodes(
            G,
            pos,
            nodelist=nodelist,
            node_shape=node_shapes[ntype],
            node_color=node_colors[ntype],
            node_size=700,
            label=ntype,
        )

    # Draw labels
    nx.draw_networkx_labels(G, pos, font_size=12, font_family="sans-serif")

    # Group parallel edges between the same pair of nodes
    edge_groups = defaultdict(list)
    for u, v, key, data in G.edges(keys=True, data=True):
        edge_groups[(u, v)].append((key, data))

    # Draw edges with adjusted 'rad' values so parallel edges don't overlap
    for (u, v), edges in edge_groups.items():
        num_edges = len(edges)
        for i, (key, data) in enumerate(edges):
            edge_type = data["type"]
            # Adjust 'rad' to spread the edges symmetrically around 0
            rad = 0.1 * (i - (num_edges - 1) / 2)
            nx.draw_networkx_edges(
                G,
                pos,
                edgelist=[(u, v)],
                edge_color=edge_colors[edge_type],
                style=edge_styles[edge_type],
                connectionstyle=f"arc3,rad={rad}",
                arrows=True,
                arrowstyle="-|>",
                arrowsize=15,
                min_source_margin=15,
                min_target_margin=15,
                width=1.5,
            )

    # Create legends for edge types and node types
    edge_legend_elements = [
        Line2D(
            [0],
            [0],
            color=edge_colors[etype],
            lw=2,
            linestyle=edge_styles[etype],
            label=etype,
        )
        for etype in edge_types
    ]
    node_legend_elements = [
        Line2D(
            [0],
            [0],
            marker=node_shapes[ntype],
            color="w",
            label=ntype,
            markerfacecolor=node_colors[ntype],
            markersize=15,
        )
        for ntype in node_types
    ]

    # Combine legends
    plt.legend(handles=edge_legend_elements + node_legend_elements, loc="upper left")
    plt.axis("off")
    # Ensure the output directory exists — savefig does not create it and
    # would otherwise raise FileNotFoundError on a fresh checkout.
    os.makedirs("plots", exist_ok=True)
    plt.savefig("plots/dp_v3.png")


def traverse_directory_structure(graph, root="/"):
    """Print the graph's 'contains' hierarchy as an ASCII tree.

    Args:
        graph: Dependency graph whose nodes carry a "type" attribute.
        root: Node id to start from (defaults to the synthetic "/" root).
    """
    def traverse(node, prefix, is_last):
        # The root prints bare; every other node gets a tree connector.
        if node == root:
            print(f"{node}")
            new_prefix = ""
        else:
            connector = "└── " if is_last else "├── "
            # Node ids are full relative paths, so the whole path is printed.
            print(f"{prefix}{connector}{node}")
            new_prefix = prefix + ("    " if is_last else "│   ")

        # Stop if the current node is a file (leaf node)
        if graph.nodes[node].get("type") == "file":
            return

        # Traverse neighbors with edge type 'contains'
        # NOTE(review): is_last_child is computed over ALL neighbors, but only
        # 'contains' neighbors are printed — the "└──" connector can land on
        # the wrong child when other edge types are present. Confirm intended.
        neighbors = list(graph.neighbors(node))
        for i, neighbor in enumerate(neighbors):
            # A MultiDiGraph can hold parallel edges; check each keyed edge.
            for key in graph[node][neighbor]:
                if graph[node][neighbor][key].get("type") == "contains":
                    is_last_child = i == len(neighbors) - 1
                    traverse(neighbor, new_prefix, is_last_child)

    traverse(root, "", False)


def save_graph_to_excel(graph, output_dir):
    """
    Save graph nodes and edges information to Excel files.

    Writes three workbooks into output_dir: nodes_info.xlsx, edges_info.xlsx
    and graph_statistics.xlsx (per-type node/edge counts).

    Args:
        graph: NetworkX MultiDiGraph
        output_dir: Directory to save Excel files
    """
    os.makedirs(output_dir, exist_ok=True)

    # Node sheet: one row per node with its metadata.
    nodes_data = [
        {
            "node_id": node_id,
            "type": attrs.get("type", ""),
            "start_line": attrs.get("start_line", ""),
            "end_line": attrs.get("end_line", ""),
            "code_length": len(attrs.get("code", "")) if attrs.get("code") else 0,
            "has_code": bool(attrs.get("code")),
        }
        for node_id, attrs in graph.nodes(data=True)
    ]
    nodes_excel_path = os.path.join(output_dir, "nodes_info.xlsx")
    pd.DataFrame(nodes_data).to_excel(nodes_excel_path, index=False)
    print(f"Nodes information saved to {nodes_excel_path}")

    # Edge sheet: one row per (multi-)edge.
    edges_data = [
        {
            "source": src,
            "target": dst,
            "edge_key": key,
            "type": attrs.get("type", ""),
            "alias": attrs.get("alias", ""),
        }
        for src, dst, key, attrs in graph.edges(keys=True, data=True)
    ]
    edges_excel_path = os.path.join(output_dir, "edges_info.xlsx")
    pd.DataFrame(edges_data).to_excel(edges_excel_path, index=False)
    print(f"Edges information saved to {edges_excel_path}")

    # Statistics sheet: totals plus per-type node and edge counts.
    node_counts = Counter(d.get("type") for _, d in graph.nodes(data=True))
    edge_counts = Counter(d.get("type") for _, _, d in graph.edges(data=True))
    metrics = ["Total Nodes", "Total Edges"]
    counts = [len(graph.nodes()), len(graph.edges())]
    for ntype in VALID_NODE_TYPES:
        metrics.append(f"Nodes - {ntype}")
        counts.append(node_counts[ntype])
    for etype in VALID_EDGE_TYPES:
        metrics.append(f"Edges - {etype}")
        counts.append(edge_counts[etype])

    stats_excel_path = os.path.join(output_dir, "graph_statistics.xlsx")
    pd.DataFrame({"Metric": metrics, "Count": counts}).to_excel(
        stats_excel_path, index=False
    )
    print(f"Graph statistics saved to {stats_excel_path}")


def main():
    """Build the dependency graph for the configured repo, persist it, and
    print summary statistics. Reads CLI options from the module-level `args`."""
    import pickle

    # Generate the dependency graph
    graph = build_graph(args.repo_path, global_import=args.global_import)
    output_path = "/data_ext/ref_code/LocAgent/demo/data/astropy_astropy.pkl"
    with open(output_path, "wb") as f:
        pickle.dump(graph, f)
    print(f"Graph saved to {output_path}")

    # Export node/edge/statistics workbooks
    excel_output_dir = "/data_ext/ref_code/LocAgent/demo/data/excel_output"
    save_graph_to_excel(graph, excel_output_dir)

    if args.visualize:
        visualize_graph(graph)

    # Summarize edge types; inherit_list gathers the imports pairs (debugging aid).
    inherit_list = [
        (src, dst)
        for src, dst, data in graph.edges(data=True)
        if data["type"] == EDGE_TYPE_IMPORTS
    ]
    edge_type_tally = [data["type"] for _, _, data in graph.edges(data=True)]
    print(Counter(edge_type_tally))

    node_type_tally = [data["type"] for _, data in graph.nodes(data=True)]
    print(Counter(node_type_tally))

    traverse_directory_structure(graph)


if __name__ == "__main__":
    # CLI entry point: parse options, then build and export the graph.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--repo_path",
        type=str,
        default="/data_ext/ref_code/LocAgent/playground/build_graph/0/astropy_astropy",
    )
    arg_parser.add_argument("--visualize", action="store_true")
    arg_parser.add_argument("--global_import", action="store_true")
    # main() reads this module-level name, so it must be called `args`.
    args = arg_parser.parse_args()

    main()
